package edu.berkeley.nlp.syntax;

import edu.berkeley.nlp.util.CollectionUtils;
import edu.berkeley.nlp.util.MapFactory;
import edu.berkeley.nlp.util.MyMethod;
import edu.berkeley.nlp.util.Pair;

import java.io.Serializable;
import java.util.*;

/**
 * Represents linguistic trees, with each node consisting of a label and a list
 * of children. Also provides functions to get a map of subtrees to
 * constituents.
 * 
 * @author Dan Klein
 */
public class Tree<L> implements Serializable, Comparable<Tree<L>>, Iterable<Tree<L>> {

	private static final long serialVersionUID = 1L;

	/** Label of this node. May be null; toString, hashCode and equals tolerate a null label. */
	L label;

	/** Child subtrees, in order. Empty for a leaf. The list is shared with callers, not copied. */
	List<Tree<L>> children;

	// ------------------------------------------------------------------
	// Construction
	// ------------------------------------------------------------------

	/**
	 * Creates a node with the given label and children. The list is stored
	 * by reference (no defensive copy).
	 * 
	 * @param label
	 *            label of the new node
	 * @param children
	 *            child subtrees of the new node
	 */
	public Tree(L label, List<Tree<L>> children) {
		this.label = label;
		this.children = children;
	}

	/**
	 * Creates a leaf node. The child list is the immutable
	 * {@link Collections#emptyList()}; use {@link #setChildren(List)} to
	 * attach children later.
	 * 
	 * @param label
	 *            label of the new leaf
	 */
	public Tree(L label) {
		this.label = label;
		this.children = Collections.emptyList();
	}

	// ------------------------------------------------------------------
	// Basic accessors
	// ------------------------------------------------------------------

	/** Replaces the i-th child in place (delegates to {@link List#set}). */
	public void setChild(int i, Tree<L> child) {
		children.set(i, child);
	}

	/** Replaces the whole child list; the given list is stored by reference. */
	public void setChildren(List<Tree<L>> c) {
		this.children = c;
	}

	/** Returns the live child list of this node (not a copy). */
	public List<Tree<L>> getChildren() {
		return children;
	}

	/** Returns the i-th child of this node. */
	public Tree<L> getChild(int i) {
		return children.get(i);
	}

	/** Returns the label of this node. */
	public L getLabel() {
		return label;
	}

	/** Sets the label of this node. */
	public void setLabel(L label) {
		this.label = label;
	}

	/** A leaf is a node without children. */
	public boolean isLeaf() {
		return getChildren().isEmpty();
	}

	/** A pre-terminal is a node whose only child is a leaf. */
	public boolean isPreTerminal() {
		List<Tree<L>> kids = getChildren();
		return kids.size() == 1 && kids.get(0).isLeaf();
	}

	/** Returns true if the yield of this tree has more than one leaf. */
	public boolean isPhrasal() {
		return getYield().size() > 1;
	}

	// ------------------------------------------------------------------
	// Yields and node collections
	// ------------------------------------------------------------------

	/** Returns the labels of all leaves, in traversal order of the child lists. */
	public List<L> getYield() {
		List<L> yield = new ArrayList<L>();
		appendYield(this, yield);
		return yield;
	}

	private static <L> void appendYield(Tree<L> tree, List<L> yield) {
		if (tree.isLeaf()) {
			yield.add(tree.getLabel());
			return;
		}
		for (Tree<L> child : tree.getChildren()) {
			appendYield(child, yield);
		}
	}

	/** Returns all leaf nodes, in traversal order of the child lists. */
	public List<Tree<L>> getTerminals() {
		List<Tree<L>> yield = new ArrayList<Tree<L>>();
		appendTerminals(this, yield);
		return yield;
	}

	private static <L> void appendTerminals(Tree<L> tree, List<Tree<L>> yield) {
		if (tree.isLeaf()) {
			yield.add(tree);
			return;
		}
		for (Tree<L> child : tree.getChildren()) {
			appendTerminals(child, yield);
		}
	}

	/** Returns the labels of all leaf nodes (same order as {@link #getTerminals()}). */
	public List<L> getTerminalYield() {
		List<L> yield = new ArrayList<L>();
		for (Tree<L> term : getTerminals()) {
			yield.add(term.getLabel());
		}
		return yield;
	}

	/**
	 * Returns every node that is not a leaf (pre-terminals included), parent
	 * before children.
	 */
	public List<Tree<L>> getNonTerminals() {
		List<Tree<L>> yield = new ArrayList<Tree<L>>();
		appendNonTerminals(this, yield);
		return yield;
	}

	private static <L> void appendNonTerminals(Tree<L> tree, List<Tree<L>> yield) {
		if (tree.isLeaf()) {
			return;
		}
		yield.add(tree);
		for (Tree<L> child : tree.getChildren()) {
			appendNonTerminals(child, yield);
		}
	}

	/** Returns the labels of all pre-terminal nodes. */
	public List<L> getPreTerminalYield() {
		List<L> yield = new ArrayList<L>();
		appendPreTerminalYield(this, yield);
		return yield;
	}

	private static <L> void appendPreTerminalYield(Tree<L> tree, List<L> yield) {
		if (tree.isPreTerminal()) {
			yield.add(tree.getLabel());
			return;
		}
		for (Tree<L> child : tree.getChildren()) {
			appendPreTerminalYield(child, yield);
		}
	}

	/** Returns all pre-terminal nodes. */
	public List<Tree<L>> getPreTerminals() {
		List<Tree<L>> preterms = new ArrayList<Tree<L>>();
		appendPreTerminals(this, preterms);
		return preterms;
	}

	private static <L> void appendPreTerminals(Tree<L> tree, List<Tree<L>> yield) {
		if (tree.isPreTerminal()) {
			yield.add(tree);
			return;
		}
		for (Tree<L> child : tree.getChildren()) {
			appendPreTerminals(child, yield);
		}
	}

	/**
	 * Returns all nodes whose {@link #getDepth()} equals the given value.
	 * Note that {@link #getDepth()} measures the height of a subtree (a leaf
	 * has depth 1), not the distance from the root; see
	 * {@link #getAtDepth(int)} for the latter.
	 * 
	 * @param depth
	 *            the subtree height to look for
	 * @return the matching nodes, in left-to-right order
	 */
	public List<Tree<L>> getTreesOfDepth(int depth) {
		List<Tree<L>> trees = new ArrayList<Tree<L>>();
		collectTreesOfHeight(this, trees, depth);
		return trees;
	}

	/**
	 * Single bottom-up pass: computes the height of <code>tree</code> and
	 * records every node whose height equals <code>target</code>. A node is
	 * strictly taller than all of its descendants, so matching nodes are never
	 * nested and the resulting order is the same as for a top-down search.
	 * 
	 * @return the height of <code>tree</code> (1 for a leaf)
	 */
	private static <L> int collectTreesOfHeight(Tree<L> tree, List<Tree<L>> out, int target) {
		int tallestChild = 0;
		for (Tree<L> child : tree.getChildren()) {
			int childHeight = collectTreesOfHeight(child, out, target);
			if (childHeight > tallestChild) tallestChild = childHeight;
		}
		int height = tallestChild + 1;
		if (height == target) out.add(tree);
		return height;
	}

	/**
	 * Returns all nodes that are exactly <code>depth</code> edges below this
	 * node (depth 0 is this node itself). A negative depth yields an empty
	 * list.
	 */
	public List<Tree<L>> getAtDepth(int depth) {
		List<Tree<L>> yield = new ArrayList<Tree<L>>();
		appendAtDepth(depth, this, yield);
		return yield;
	}

	private static <L> void appendAtDepth(int depth, Tree<L> tree, List<Tree<L>> yield) {
		if (depth < 0) return;
		if (depth == 0) {
			yield.add(tree);
			return;
		}
		for (Tree<L> child : tree.getChildren()) {
			appendAtDepth(depth - 1, child, yield);
		}
	}

	/** Returns all nodes, each parent before its children. */
	public List<Tree<L>> getPreOrderTraversal() {
		ArrayList<Tree<L>> traversal = new ArrayList<Tree<L>>();
		traversalHelper(this, traversal, true);
		return traversal;
	}

	/** Returns all nodes, each parent after its children. */
	public List<Tree<L>> getPostOrderTraversal() {
		ArrayList<Tree<L>> traversal = new ArrayList<Tree<L>>();
		traversalHelper(this, traversal, false);
		return traversal;
	}

	private static <L> void traversalHelper(Tree<L> tree, List<Tree<L>> traversal,
			boolean preOrder) {
		if (preOrder) traversal.add(tree);
		for (Tree<L> child : tree.getChildren()) {
			traversalHelper(child, traversal, preOrder);
		}
		if (!preOrder) traversal.add(tree);
	}

	/**
	 * Returns the height of this tree counted in nodes: 1 for a leaf,
	 * otherwise 1 + the largest depth among the children.
	 */
	public int getDepth() {
		int maxDepth = 0;
		for (Tree<L> child : children) {
			int depth = child.getDepth();
			if (depth > maxDepth) maxDepth = depth;
		}
		return maxDepth + 1;
	}

	/** Returns the total number of nodes in this tree, including this node. */
	public int size() {
		int sum = 0;
		for (Tree<L> child : children) {
			sum += child.size();
		}
		return sum + 1;
	}

	// ------------------------------------------------------------------
	// Constituents and spans
	// ------------------------------------------------------------------

	/**
	 * Returns one constituent per node, stopping at pre-terminals: a leaf or
	 * pre-terminal contributes a constituent of length one carrying that
	 * node's label. Start and end indices are both inclusive.
	 */
	public Collection<Constituent<L>> getConstituentCollection() {
		Collection<Constituent<L>> constituents = new ArrayList<Constituent<L>>();
		appendConstituent(this, constituents, 0);
		return constituents;
	}

	/**
	 * Returns a map from every node of the tree (leaves included) to its
	 * constituent, with inclusive start and end indices.
	 * 
	 * John: I changed this from a hash map because it was broken as a HashMap.
	 * (The map is keyed by node identity; {@link #equals(Object)} is
	 * structural, so distinct but equal subtrees would collide in a HashMap.)
	 */
	public Map<Tree<L>, Constituent<L>> getConstituents() {
		Map<Tree<L>, Constituent<L>> constituents = new IdentityHashMap<Tree<L>, Constituent<L>>();
		appendConstituent(this, constituents, 0);
		return constituents;
	}

	/**
	 * Same as {@link #getConstituents()}, but stores the entries in a map
	 * built by the given factory.
	 * 
	 * @param mf
	 *            factory used to create the result map
	 */
	public Map<Tree<L>, Constituent<L>> getConstituents(MapFactory mf) {
		Map<Tree<L>, Constituent<L>> constituents = mf.buildMap();
		appendConstituent(this, constituents, 0);
		return constituents;
	}

	/**
	 * Groups the nodes of the tree by the span they cover. Keys are
	 * (start, end + 1) pairs derived from {@link #getConstituents()}, i.e. the
	 * inclusive end index is turned into an exclusive one. Each value list is
	 * sorted by decreasing {@link #getDepth()}, tallest subtree first.
	 */
	public Map<Pair<Integer, Integer>, List<Tree<L>>> getSpanMap() {
		Map<Tree<L>, Constituent<L>> cMap = getConstituents();
		Map<Pair<Integer, Integer>, List<Tree<L>>> spanMap = new HashMap<Pair<Integer, Integer>, List<Tree<L>>>();
		for (Map.Entry<Tree<L>, Constituent<L>> entry : cMap.entrySet()) {
			Tree<L> t = entry.getKey();
			Constituent<L> c = entry.getValue();
			Pair<Integer, Integer> span = Pair.newPair(c.getStart(), c.getEnd() + 1);
			CollectionUtils.addToValueList(spanMap, span, t);
		}
		for (List<Tree<L>> trees : spanMap.values()) {
			Collections.sort(trees, new Comparator<Tree<L>>() {
				public int compare(Tree<L> t1, Tree<L> t2) {
					// Depths are always >= 1, so the subtraction cannot overflow.
					return t2.getDepth() - t1.getDepth();
				}
			});
		}
		return spanMap;
	}

	/**
	 * Records a constituent for <code>tree</code> and all of its descendants.
	 * 
	 * @param index
	 *            position of the first leaf of <code>tree</code>
	 * @return the number of leaves covered by <code>tree</code>
	 */
	private static <L> int appendConstituent(Tree<L> tree,
			Map<Tree<L>, Constituent<L>> constituents, int index) {
		if (tree.isLeaf()) {
			Constituent<L> c = new Constituent<L>(tree.getLabel(), index, index);
			constituents.put(tree, c);
			return 1; // Length of a leaf constituent
		} else {
			int nextIndex = index;
			for (Tree<L> kid : tree.getChildren()) {
				nextIndex += appendConstituent(kid, constituents, nextIndex);
			}
			Constituent<L> c = new Constituent<L>(tree.getLabel(), index, nextIndex - 1);
			constituents.put(tree, c);
			return nextIndex - index; // Total length covered by the children
		}
	}

	/**
	 * Like the map-based variant, but treats pre-terminals as the bottom of
	 * the tree: a pre-terminal gets a single length-one constituent and its
	 * leaf gets none.
	 * 
	 * @param index
	 *            position of the first leaf of <code>tree</code>
	 * @return the number of leaves covered by <code>tree</code>
	 */
	private static <L> int appendConstituent(Tree<L> tree,
			Collection<Constituent<L>> constituents, int index) {
		if (tree.isLeaf() || tree.isPreTerminal()) {
			Constituent<L> c = new Constituent<L>(tree.getLabel(), index, index);
			constituents.add(c);
			return 1; // Length of a leaf / pre-terminal constituent
		} else {
			int nextIndex = index;
			for (Tree<L> kid : tree.getChildren()) {
				nextIndex += appendConstituent(kid, constituents, nextIndex);
			}
			Constituent<L> c = new Constituent<L>(tree.getLabel(), index, nextIndex - 1);
			constituents.add(c);
			return nextIndex - index; // Total length covered by the children
		}
	}

	/**
	 * Returns a constituent for the lowest node whose span contains
	 * [i, j). Unlike {@link #getConstituents()}, offsets here are
	 * start-inclusive and end-exclusive (the root spans 0 to the yield size).
	 */
	public Constituent<L> getLeastCommonAncestorConstituent(int i, int j) {
		final List<L> yield = getYield();
		return getLeastCommonAncestorConstituentHelper(this, 0, yield.size(), i, j);
	}

	private static <L> Constituent<L> getLeastCommonAncestorConstituentHelper(Tree<L> tree,
			int start, int end, int i, int j) {

		if (start == i && end == j) return new Constituent<L>(tree.getLabel(), start, end);

		int childStart = start;
		for (Tree<L> child : tree.getChildren()) {
			final int childEnd = childStart + child.getYield().size();
			if (childStart <= i && childEnd >= j) {
				// The first child covering the span fully determines the answer.
				return getLeastCommonAncestorConstituentHelper(child, childStart, childEnd, i, j);
			}
			childStart = childEnd;
		}
		// No single child covers [i, j): this node is the lowest one that does.
		return new Constituent<L>(tree.getLabel(), start, end);
	}

	/**
	 * Returns the highest node whose span is exactly [i, j) (start-inclusive,
	 * end-exclusive leaf offsets), or null if no node has that span.
	 */
	public Tree<L> getTopTreeForSpan(int i, int j) {
		final List<L> yield = getYield();
		return getTopTreeForSpanHelper(this, 0, yield.size(), i, j);
	}

	private static <L> Tree<L> getTopTreeForSpanHelper(Tree<L> tree, int start, int end,
			int i, int j) {

		assert i <= j;
		if (start == i && end == j) {
			assert tree.getLabel().toString().matches("\\w+");
			return tree;
		}

		int childStart = start;
		for (Tree<L> child : tree.getChildren()) {
			final int childEnd = childStart + child.getYield().size();
			if (childStart <= i && childEnd >= j) {
				return getTopTreeForSpanHelper(child, childStart, childEnd, i, j);
			}
			childStart = childEnd;
		}
		return null;
	}

	// ------------------------------------------------------------------
	// Copying and transformation
	// ------------------------------------------------------------------

	/**
	 * Clone the structure of the tree. Unfortunately, the new labels are copied
	 * by reference from the current tree.
	 * 
	 * @return a new tree with freshly allocated nodes sharing this tree's labels
	 */
	public Tree<L> shallowClone() {
		ArrayList<Tree<L>> newChildren = new ArrayList<Tree<L>>(children.size());
		for (Tree<L> child : children) {
			newChildren.add(child.shallowClone());
		}
		return new Tree<L>(label, newChildren);
	}

	/**
	 * Return a clone of just the root node of this tree (with no children)
	 * 
	 * @return a new leaf carrying this node's label
	 */
	public Tree<L> shallowCloneJustRoot() {
		return new Tree<L>(label);
	}

	/**
	 * Applies a transformation to all labels in the tree and returns the
	 * resulting tree. Children are transformed before the label of their
	 * parent.
	 * 
	 * @param <O>
	 *            Output type of the transformation
	 * @param trans
	 *            The transformation to apply
	 * @return Transformed tree
	 */
	public <O> Tree<O> transformNodes(MyMethod<L, O> trans) {
		ArrayList<Tree<O>> newChildren = new ArrayList<Tree<O>>(children.size());
		for (Tree<L> child : children) {
			newChildren.add(child.transformNodes(trans));
		}
		return new Tree<O>(trans.call(label), newChildren);
	}

	/**
	 * Applies a transformation to all nodes in the tree and returns the
	 * resulting tree. Different from <code>transformNodes</code> in that you
	 * get the full node and not just the label. The transformation is called
	 * on a node before it is called on any of its children (pre-order).
	 * 
	 * @param <O>
	 *            Output type of the transformation
	 * @param trans
	 *            The transformation to apply
	 * @return Transformed tree
	 */
	public <O> Tree<O> transformNodesUsingNode(MyMethod<Tree<L>, O> trans) {
		ArrayList<Tree<O>> newChildren = new ArrayList<Tree<O>>(children.size());
		O newLabel = trans.call(this);
		for (Tree<L> child : children) {
			newChildren.add(child.transformNodesUsingNode(trans));
		}
		return new Tree<O>(newLabel, newChildren);
	}

	/**
	 * Post-order variant of {@link #transformNodesUsingNode(MyMethod)}: the
	 * transformation is called on all descendants of a node before it is
	 * called on the node itself, at every level of the tree.
	 * 
	 * @param <O>
	 *            Output type of the transformation
	 * @param trans
	 *            The transformation to apply
	 * @return Transformed tree
	 */
	public <O> Tree<O> transformNodesUsingNodePostOrder(MyMethod<Tree<L>, O> trans) {
		ArrayList<Tree<O>> newChildren = new ArrayList<Tree<O>>(children.size());
		for (Tree<L> child : children) {
			// Recurse into the post-order variant so the whole tree, not just
			// the root, is visited children-first.
			newChildren.add(child.transformNodesUsingNodePostOrder(trans));
		}
		O newLabel = trans.call(this);
		return new Tree<O>(newLabel, newChildren);
	}

	// ------------------------------------------------------------------
	// Subtree enumeration and iteration
	// ------------------------------------------------------------------

	/**
	 * Get the set of all subtrees inside the tree by returning a tree rooted at
	 * each node. These are <i>not</i> copies, but all share structure. The
	 * tree is regarded as a subtree of itself. Because the result is a
	 * <code>HashSet</code> and {@link #equals(Object)} is structural, subtrees
	 * that are equal to each other appear only once.
	 * 
	 * @return the <code>Set</code> of all subtrees in the tree.
	 */
	public Set<Tree<L>> subTrees() {
		Set<Tree<L>> nodes = new HashSet<Tree<L>>();
		subTrees(nodes);
		return nodes;
	}

	/**
	 * Get the list of all subtrees inside the tree by returning a tree rooted
	 * at each node. These are <i>not</i> copies, but all share structure. The
	 * tree is regarded as a subtree of itself.
	 * 
	 * @return the <code>List</code> of all subtrees in the tree.
	 */
	public List<Tree<L>> subTreeList() {
		List<Tree<L>> nodes = new ArrayList<Tree<L>>();
		subTrees(nodes);
		return nodes;
	}

	/**
	 * Add the set of all subtrees inside a tree (including the tree itself) to
	 * the given <code>Collection</code>.
	 * 
	 * @param n
	 *            A collection of nodes to which the subtrees will be added
	 * @return The collection parameter with the subtrees added
	 */
	public Collection<Tree<L>> subTrees(Collection<Tree<L>> n) {
		n.add(this);
		for (Tree<L> kid : getChildren()) {
			kid.subTrees(n);
		}
		return n;
	}

	/**
	 * Returns an iterator over the nodes of the tree. This method implements
	 * the <code>iterator()</code> method required by the
	 * <code>Iterable</code> interface. It does a preorder (node before its
	 * children) traversal of the tree. (A possible extension to the class at
	 * some point would be to allow different traversal orderings via variant
	 * iterators.)
	 * 
	 * @return An iterator over the nodes of the tree
	 */
	public Iterator<Tree<L>> iterator() {
		return new TreeIterator();
	}

	/** Stack-based pre-order iterator over the enclosing tree. */
	private class TreeIterator implements Iterator<Tree<L>> {

		/** Nodes still to be visited; the next node is at the end of the list. */
		private List<Tree<L>> treeStack;

		private TreeIterator() {
			treeStack = new ArrayList<Tree<L>>();
			treeStack.add(Tree.this);
		}

		public boolean hasNext() {
			return (!treeStack.isEmpty());
		}

		/**
		 * @throws NoSuchElementException
		 *             if all nodes have already been returned
		 */
		public Tree<L> next() {
			if (treeStack.isEmpty()) {
				throw new NoSuchElementException();
			}
			Tree<L> tr = treeStack.remove(treeStack.size() - 1);
			List<Tree<L>> kids = tr.getChildren();
			// so that we can efficiently use one List, we reverse them
			for (int i = kids.size() - 1; i >= 0; i--) {
				treeStack.add(kids.get(i));
			}
			return tr;
		}

		/**
		 * Not supported
		 */
		public void remove() {
			throw new UnsupportedOperationException();
		}

	}

	// ------------------------------------------------------------------
	// String rendering
	// ------------------------------------------------------------------

	@Override
	public String toString() {
		StringBuilder sb = new StringBuilder();
		toStringBuilder(sb);
		return sb.toString();
	}

	/**
	 * Appends a bracketed rendering of this tree to <code>sb</code>: a leaf is
	 * written as its label, any other node as
	 * <code>(label child1 child2 ...)</code>. A null label is written as
	 * nothing.
	 */
	public void toStringBuilder(StringBuilder sb) {
		if (!isLeaf()) sb.append('(');
		if (getLabel() != null) {
			sb.append(getLabel());
		}
		if (!isLeaf()) {
			for (Tree<L> child : getChildren()) {
				sb.append(' ');
				child.toStringBuilder(sb);
			}
			sb.append(')');
		}
	}

	/**
	 * Same as toString(), but escapes terminals like so:
	 * ( becomes -LRB-
	 * ) becomes -RRB-
	 * \ becomes -BACKSLASH- ("\" does not occur in PTB; this is our own convention)
	 * This is useful because otherwise it's hard to tell a "(" terminal from the tree's bracket
	 * structure, or tell an escaping \ from a literal.
	 */
	public String toEscapedString() {
		StringBuilder sb = new StringBuilder();
		toStringBuilderEscaped(sb);
		return sb.toString();
	}

	/**
	 * Appends the escaped rendering described at {@link #toEscapedString()} to
	 * <code>sb</code>. Only leaf labels are escaped.
	 */
	public void toStringBuilderEscaped(StringBuilder sb) {
		if (!isLeaf()) sb.append('(');
		if (getLabel() != null) {
			if (isLeaf()) {
				// Literal (non-regex) replacements, applied in this order.
				String escapedLabel = getLabel().toString();
				escapedLabel = escapedLabel.replace("(", "-LRB-");
				escapedLabel = escapedLabel.replace(")", "-RRB-");
				escapedLabel = escapedLabel.replace("\\", "-BACKSLASH-");
				sb.append(escapedLabel);
			} else {
				sb.append(getLabel());
			}
		}
		if (!isLeaf()) {
			for (Tree<L> child : getChildren()) {
				sb.append(' ');
				child.toStringBuilderEscaped(sb);
			}
			sb.append(')');
		}
	}

	// ------------------------------------------------------------------
	// Equality and ordering
	// ------------------------------------------------------------------

	@Override
	public int hashCode() {
		final int prime = 31;
		int result = 1;
		result = prime * result + ((label == null) ? 0 : label.hashCode());
		for (Tree<L> child : children) {
			result = prime * result + ((child == null) ? 0 : child.hashCode());
		}
		return result;
	}

	/**
	 * Structural equality: two trees are equal if they are of the same class,
	 * have equal labels (both null counts as equal) and pairwise equal
	 * children.
	 */
	@Override
	public boolean equals(Object obj) {
		if (this == obj) return true;
		if (obj == null || getClass() != obj.getClass()) return false;
		final Tree<?> other = (Tree<?>) obj;
		// hashCode accepts null labels and null children, so equals must too.
		if (!nullSafeEquals(this.label, other.label)) return false;
		final int childCount = getChildren().size();
		if (childCount != other.getChildren().size()) return false;
		for (int i = 0; i < childCount; ++i) {
			if (!nullSafeEquals(getChildren().get(i), other.getChildren().get(i))) return false;
		}
		return true;
	}

	/** Null-tolerant <code>a.equals(b)</code>. */
	private static boolean nullSafeEquals(Object a, Object b) {
		return (a == null) ? (b == null) : a.equals(b);
	}

	/**
	 * Orders trees by label, then by number of children, then child by child.
	 * 
	 * NOTE: the label comparison is <code>o.label.compareTo(this.label)</code>,
	 * i.e. reversed with respect to the labels' natural order, while the
	 * child-count comparison is ascending. This is kept as is so that existing
	 * sort orders do not change.
	 * 
	 * @throws IllegalArgumentException
	 *             if either label is not <code>Comparable</code> (this includes
	 *             null labels)
	 */
	@SuppressWarnings("unchecked")
	public int compareTo(Tree<L> o) {
		if (!(o.getLabel() instanceof Comparable && getLabel() instanceof Comparable))
			throw new IllegalArgumentException("Tree labels are not comparable");
		int cmp = ((Comparable<Object>) o.getLabel()).compareTo(getLabel());
		if (cmp != 0) return cmp;
		final int mySize = this.getChildren().size();
		final int otherSize = o.getChildren().size();
		if (mySize != otherSize) return (mySize < otherSize) ? -1 : 1;
		for (int i = 0; i < mySize; ++i) {
			int cmp3 = getChildren().get(i).compareTo(o.getChildren().get(i));
			if (cmp3 != 0) return cmp3;
		}
		return 0;
	}

	// ------------------------------------------------------------------
	// Unary chains
	// ------------------------------------------------------------------

	/**
	 * Returns true if, below the single child of this node, there is a node
	 * with exactly one child that is not a pre-terminal. Asserts that this
	 * node has exactly one child.
	 */
	public boolean hasUnariesOtherThanRoot() {
		assert children.size() == 1;
		return hasUnariesHelper(children.get(0));
	}

	private boolean hasUnariesHelper(Tree<L> tree) {
		if (tree.isPreTerminal()) return false;
		if (tree.getChildren().size() == 1) return true;
		for (Tree<L> child : tree.getChildren()) {
			if (hasUnariesHelper(child)) return true;
		}
		return false;
	}

	/**
	 * Returns true if the tree contains two directly nested single-child
	 * nodes where the upper one's child is not a pre-terminal.
	 */
	public boolean hasUnaryChain() {
		return hasUnaryChainHelper(this, false);
	}

	/**
	 * @param unaryAbove
	 *            true if the parent of <code>tree</code> has exactly one child
	 */
	private boolean hasUnaryChainHelper(Tree<L> tree, boolean unaryAbove) {
		List<Tree<L>> kids = tree.getChildren();
		if (kids.size() == 1) {
			if (unaryAbove) return true;
			Tree<L> onlyChild = kids.get(0);
			if (onlyChild.isPreTerminal()) return false;
			return hasUnaryChainHelper(onlyChild, true);
		}
		for (Tree<L> child : kids) {
			if (!child.isPreTerminal() && hasUnaryChainHelper(child, false)) return true;
		}
		return false;
	}

	/**
	 * Collapses unary chains in place: for every run of single-child,
	 * non-pre-terminal nodes, the topmost node is kept and the nodes below it
	 * in the run are spliced out.
	 */
	public void removeUnaryChains() {
		removeUnaryChainHelper(this, null);
	}

	/**
	 * @param parent
	 *            the retained top of the unary chain that <code>tree</code>
	 *            hangs off, or null if <code>tree</code> is not directly below
	 *            such a node. When non-null it has exactly one child, which is
	 *            why index 0 is overwritten.
	 */
	private void removeUnaryChainHelper(Tree<L> tree, Tree<L> parent) {
		if (tree.isLeaf()) return;
		if (tree.getChildren().size() == 1 && !tree.isPreTerminal()) {
			Tree<L> onlyChild = tree.getChildren().get(0);
			if (parent != null) {
				// tree is an intermediate link: attach its child to the chain top.
				parent.getChildren().set(0, onlyChild);
				removeUnaryChainHelper(onlyChild, parent);
			} else {
				// tree is the top of a (potential) chain and is kept.
				removeUnaryChainHelper(onlyChild, tree);
			}
		} else {
			for (Tree<L> child : tree.getChildren()) {
				if (!child.isPreTerminal()) removeUnaryChainHelper(child, null);
			}
		}
	}

}
